/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file sort_v100_impl.h
 * \brief Helper that merges (or copies) the leftover tail queues of a merge-sort pass.
 */
#ifndef IMPL_SORT_SORT_SORT_PRE_IMPL_H
#define IMPL_SORT_SORT_SORT_PRE_IMPL_H

namespace AscendC {
// Number of queue slots taken by one merge call. It sizes the per-queue count arrays handed to MrgSort
// and, multiplied by the queue stride, gives the span covered by one full merge round.
constexpr uint32_t SORT_LEN = 4;

/*!
 * \brief Handles the queues left over after the full merge rounds.
 *
 * The tail region begins at index SORT_LEN * baseOffset * mergeTmpRepeatTimes in both tensors.
 * Behavior by mergeTmpTailQueNum:
 *   - <= 0        : nothing to do, returns immediately;
 *   - == 1        : the single queue is copied from tmpLocal to dstLocal with DataCopy;
 *   - == 2 or 3   : that many queues are merged with MrgSort; unused slots point at index 0 with length 0;
 *   - otherwise   : all SORT_LEN queues are merged.
 * In every merge case the last active queue has elementCountTail entries and every earlier queue has
 * singleMergeTmpElementCount entries.
 *
 * \param dstLocal                   tensor that receives the merged (or copied) result
 * \param tmpLocal                   tensor holding the queues to merge
 * \param baseOffset                 index distance between the starts of two consecutive queues
 * \param singleMergeTmpElementCount entry count of every queue except the last active one
 * \param elementCountTail           entry count of the last active queue
 * \param mergeTmpRepeatTimes        number of full merge rounds preceding the tail region
 * \param mergeTmpTailQueNum         number of queues in the tail region
 */
template <typename T>
__aicore__ inline void ComSortInnerLoopTail(const LocalTensor<T> &dstLocal, const LocalTensor<T> &tmpLocal,
    const uint32_t baseOffset, const uint16_t singleMergeTmpElementCount, const uint32_t elementCountTail,
    const int32_t mergeTmpRepeatTimes, int32_t mergeTmpTailQueNum)
{
    if (mergeTmpTailQueNum <= 0) {
        return;
    }

    const uint32_t tailStart = SORT_LEN * baseOffset * mergeTmpRepeatTimes;

    // A lone queue has nothing to merge against, so it is moved to the destination unchanged.
    if (mergeTmpTailQueNum == 1) {
        // NOTE(review): presumably one sorted entry spans 4 half elements or 2 elements of any other
        // type (i.e. 8 bytes for 4-byte types) -- confirm against the sort data layout.
        constexpr uint32_t unitsPerEntry = IsSameType<T, half>::value ? 4 : 2;
        DataCopy(dstLocal[tailStart], tmpLocal[tailStart], elementCountTail * unitsPerEntry);
        return;
    }

    // Only 2 and 3 select a partial merge; every other remaining value uses all SORT_LEN slots.
    uint32_t activeQueues = SORT_LEN;
    if (mergeTmpTailQueNum == 2 || mergeTmpTailQueNum == 3) {
        activeQueues = static_cast<uint32_t>(mergeTmpTailQueNum);
    }

    // Inactive slots keep start index 0 and length 0.
    uint32_t queueStart[SORT_LEN] = {0, 0, 0, 0};
    uint16_t queueLength[SORT_LEN] = {0, 0, 0, 0};
    for (uint32_t slot = 0; slot < activeQueues; ++slot) {
        queueStart[slot] = tailStart + slot * baseOffset;
        queueLength[slot] = singleMergeTmpElementCount;
    }
    // The last active queue carries the tail count, narrowed to the 16-bit length type.
    queueLength[activeQueues - 1] = static_cast<uint16_t>(elementCountTail);

    // One bit per active queue, lowest bit first: 0b0011, 0b0111 or 0b1111.
    uint16_t activeMask = static_cast<uint16_t>((1U << activeQueues) - 1U);

    MrgSortSrcList mergeInputs = MrgSortSrcList(tmpLocal[queueStart[0]], tmpLocal[queueStart[1]],
        tmpLocal[queueStart[2]], tmpLocal[queueStart[3]]);
    uint32_t mergedCounts[SORT_LEN];
    MrgSort<T>(dstLocal[tailStart], mergeInputs, queueLength, mergedCounts, activeMask, 1);
}

}  // namespace AscendC

#endif  // IMPL_SORT_SORT_SORT_PRE_IMPL_H